import os
import argparse
import glob
import torch
import numpy as np
#import clip
from PIL import Image
#import open_clip
#import clip
from embedding_projection import ImageEmbeddingProjector, EmbeddingProjectionVisualizer
from torchvision import transforms
from models.FinetuneVTmodels import *
from models.MIL_VT import *
from models.FinetuneVTmodels import MIL_VT_FineTune

def _sorted_png_paths(directory):
    """Return the paths of the ``*.png`` files directly inside *directory*, sorted.

    Sorting gives a deterministic order; ``main`` relies on it because it
    assigns display titles to images by their position in this list.
    """
    return sorted(glob.glob(os.path.join(directory, '*.png')))


def _load_finetuned_model(checkpoint_path, device):
    """Build ``MIL_VT_FineTune``, load weights from *checkpoint_path*, move it to *device*.

    The checkpoint is deserialised onto the CPU first (``map_location``) so a
    file saved from a GPU run can also be read on a CPU-only machine. Both a
    wrapper dict holding a ``'state_dict'`` entry and a bare state dict are
    accepted. The returned model is in eval mode.
    """
    model = MIL_VT_FineTune()
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        # Treat the whole file as the state dict (e.g. saved via model.state_dict()).
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model


def main():
    """Embed two image sets with a fine-tuned MIL-VT model and visualise them.

    PNG images in ``--image_dir1`` are embedded and used to fit a PCA basis.
    PNG images in ``--image_dir2`` are embedded and projected onto that basis.
    The projections, per-image titles, and two hard-coded score matrices are
    then passed, together with ``--output_path``, to
    ``EmbeddingProjectionVisualizer.plot_projections_and_matrices``.

    Exits via ``argparse`` (status 2) if ``--image_dir1`` has no PNG files or
    ``--image_dir2`` has fewer PNG files than there are titles.
    """
    parser = argparse.ArgumentParser(description="Image Embedding Processing and Visualization")
    parser.add_argument("--image_dir1", type=str, required=True, help="Directory containing original image files")
    parser.add_argument("--image_dir2", type=str, required=True, help="Directory containing image files for projection")
    parser.add_argument("--output_path", type=str, required=True, help="Output path for visualizations")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        # Placeholder default: pass a real checkpoint file on the command line.
        default='/path/to/saved/weight/file',
        help="Path to the MIL_VT_FineTune checkpoint file",
    )
    args = parser.parse_args()

    # Display titles for the images of --image_dir2, assigned in sorted-path order.
    image_titles = (
        "no_DR",
        "moderate",
        "severe",
        "no_DR -> severe",
        "moderate -> severe",
        "severe -> no_DR",
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Collect and validate the inputs before the (slow) model load and embedding.
    image_paths1 = _sorted_png_paths(args.image_dir1)
    if not image_paths1:
        parser.error(f"no .png files found in --image_dir1 ({args.image_dir1!r}); PCA needs at least one image")
    image_paths2 = _sorted_png_paths(args.image_dir2)
    if len(image_paths2) < len(image_titles):
        parser.error(
            f"--image_dir2 ({args.image_dir2!r}) must contain at least {len(image_titles)} .png files, "
            f"found {len(image_paths2)}; titles are assigned by sorted position"
        )

    model = _load_finetuned_model(args.checkpoint_path, device)

    # Resize to 384x384, then Normalize(0.5, 0.5) maps ToTensor's [0, 1] range
    # (for 8-bit images) to [-1, 1].
    # NOTE(review): presumably mirrors the model's training preprocessing -- confirm.
    preprocess = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    projector = ImageEmbeddingProjector(model, preprocess, device=device)

    # Fit the PCA basis on the first set; its per-image map is not needed.
    _, reference_matrix = projector.load_and_embed_images(image_paths1)
    projector.fit_pca(reference_matrix)

    # Embed the second set and project it onto the fitted principal components.
    query_map, query_matrix = projector.load_and_embed_images(image_paths2)
    emb_projection_map = projector.project_embeddings(query_matrix, query_map, image_paths2)

    # zip stops at the shorter sequence, so images past the last title are left out.
    img_title_map = dict(zip(image_paths2, image_titles))

    # Class labels for the matrix axes.
    texts = ["No DR", "Moderate", "Severe"]

    # Hard-coded 3x3 score matrices handed to the visualizer.
    # NOTE(review): presumably rows are images and columns follow `texts`; confirm.
    mat_org = np.array([
        [9.9889e-01, 6.0757e-04, 1.9922e-04],
        [1.4922e-03, 9.9500e-01, 5.2162e-04],
        [4.8594e-04, 2.4362e-03, 9.9583e-01],
    ])

    # NOTE(review): every other row in both matrices sums to ~1, but row 2 here
    # sums to ~1.50. If these are softmax probabilities, 5.5881e-01 is probably
    # a typo for 5.5881e-02 (row sum would then be ~0.999). Confirm against the
    # run that produced these numbers.
    mat_final = np.array([
        [2.5102e-03, 7.3896e-03, 9.8772e-01],
        [5.4882e-04, 5.5881e-01, 9.4214e-01],
        [9.9859e-01, 5.1972e-04, 2.2735e-04],
    ])

    visualizer = EmbeddingProjectionVisualizer(
        emb_projection_map, img_title_map, mat_org, mat_final, texts, args.output_path
    )
    visualizer.plot_projections_and_matrices()

if __name__ == "__main__":
    # Script entry point: run the pipeline only when executed directly, not on import.
    main()
